#!/usr/bin/env python
# coding: utf-8


#Import libraries
import pandas as pd
import numpy as np
import sys
import pysam
import re
import matplotlib.pyplot as plt
from collections import defaultdict
import logging
import os
import contextlib


# Capture the command-line arguments passed in by the shell script
args = sys.argv[1:]

# Positional arguments expected from the shell wrapper, in order
_EXPECTED_ARGS = (
    "sample_name",
    "outdir",
    "fasta_filename",
    "contig_file",
    "discard_deletions",
    "experiment_condition",
    "desired_contig",
)
if len(args) < len(_EXPECTED_ARGS):
    # Fail with a usage message instead of an IndexError on the first missing argument
    _script_name = os.path.basename(sys.argv[0]) if sys.argv else "script"
    sys.exit(
        "Usage: " + _script_name + " " + " ".join(_EXPECTED_ARGS)
        + f"\nExpected {len(_EXPECTED_ARGS)} arguments, got {len(args)}."
    )


# Import Bash variables (every value arrives as a string)
sample_name = args[0]
outdir = args[1]
fasta_filename = args[2]
contig_file = args[3]
discard_deletions = args[4]
experiment_condition = args[5]  # if 'Yes', then this function is being run on an experimental library and not a plasmid library
desired_contig = args[6]  # will be used to output dataframe of counts



# Convert Aligned file into Mutations Text file: all per-sample files live in <outdir>/<sample_name>
sample_dir = outdir + "/" + sample_name
bam_filename = sample_dir + "/aligned.bam"
output_filename = sample_dir + "/output_positions_Seq.txt"
reads_without_cs = sample_dir + "/reads_without_cs_tag.log"
log_file = sample_dir + "/log.log"
text_filename = sample_dir + "/text.txt"
discard_text_filename = sample_dir + "/discard_text.txt"  # For reads that don't start or finish their alignment at the beginning/end of the contig
plot_path = sample_dir
summary_path = sample_dir + "/summary.txt"


# Define the ContigData class, which serves as a container for quickly tabulating mutations
class ContigData:
    """
    Per-contig mutation tally.

    ``self.data`` maps a contig name to a dict of counters. A contig created by ``add_contig`` holds:
      - 'substitutions' / 'deletions': dict position -> count
      - 'insertions', 'ambiguous_deletions', 'WT', 'No_CS', 'discarded', 'Incomplete_Alignment': int counts
      - 'length': the contig length passed to ``add_contig`` (None if not given)
    The special 'No_Contig' entry starts with only a 'No_CS' counter; other integer counters are
    created on it the first time they are incremented.
    """

    def __init__(self):
        self.data = {}
        # 'No_Contig' collects reads that have no reference name; it exists from the start
        self.data['No_Contig'] = {
            'No_CS': 0  # For reads not aligning to known contigs and without a 'cs' tag
        }

    def add_contig(self, contig_name, contig_length=None):
        """Create the full set of counters for ``contig_name`` unless it already exists."""
        if contig_name not in self.data:
            self.data[contig_name] = {
                'substitutions': {},
                'deletions': {},
                'insertions': 0,
                'ambiguous_deletions': 0,
                'WT': 0,
                'No_CS': 0,
                'discarded': 0,  # reads rejected because they contain a deletion
                'Incomplete_Alignment': 0,  # reads rejected because they do not span the contig
                'length': contig_length,
            }

    def _increment(self, contig_name, key):
        """Add one to an integer counter, creating the contig (and the counter) if needed."""
        if contig_name not in self.data:
            self.add_contig(contig_name)
        counters = self.data[contig_name]
        # .get() so entries without this key (e.g. 'No_Contig') are handled instead of raising KeyError
        counters[key] = counters.get(key, 0) + 1

    def update_incomplete_alignment(self, contig_name):
        """Count one read whose alignment does not cover the whole contig."""
        # Goes through add_contig so a newly seen contig gets every counter, not a bare dict
        self._increment(contig_name, 'Incomplete_Alignment')

    def update_unaligned_read(self):
        """Count one read that has no reference contig and no 'cs' tag."""
        self.data['No_Contig']['No_CS'] += 1

    def initialize_contigs(self, contig_ranges, consecutive_positions):
        """
        Pre-create zeroed per-position counters for every contig range.

        contig_ranges (dict): contig name -> (start, end, contig_length); start..end inclusive.
        consecutive_positions (dict): contig name -> positions that get no deletion counter
                                      (homopolymer positions, as built by this script).
        """
        for contig, (start, end, contig_length) in contig_ranges.items():
            self.add_contig(contig, contig_length=contig_length)

            # Prepare the set of positions to exclude from deletion tracking
            exclude_positions = set(consecutive_positions.get(contig, []))

            for position in range(start, end + 1):
                if position not in exclude_positions:
                    self.data[contig]['deletions'][position] = 0
                self.data[contig]['substitutions'][position] = 0  # substitutions are tracked everywhere

    def update_substitution(self, contig_name, position, info):
        """Count one substitution at ``position``; ``info`` (the read base) is accepted but not stored."""
        if contig_name not in self.data:
            self.add_contig(contig_name)
        substitutions = self.data[contig_name].setdefault('substitutions', {})
        # Positions outside the initialized range are created on first use
        substitutions[position] = substitutions.get(position, 0) + 1

    def update_deletion(self, contig_name, position):
        """Count one deletion at ``position``."""
        if contig_name not in self.data:
            self.add_contig(contig_name)
        deletions = self.data[contig_name].setdefault('deletions', {})
        deletions[position] = deletions.get(position, 0) + 1

    def update_simple_count(self, contig_name, mutation_type):
        """Increment a per-contig integer counter such as 'insertions', 'ambiguous_deletions', 'WT' or 'No_CS'."""
        self._increment(contig_name, mutation_type)

    def update_discarded(self, contig_name):
        """Count one read discarded because of a deletion."""
        self._increment(contig_name, 'discarded')

    def get_data(self):
        """Return the underlying (mutable) counter dictionary."""
        return self.data


#Define all functions


# One token of a minimap2 'cs' string; exactly one named group is set per match.
_CS_TOKEN_RE = re.compile(
    r":(?P<match>\d+)"                          # identical block, given as a length (short form)
    r"|=(?P<match_seq>[A-Za-z]+)"               # identical block, spelled out (long form)
    r"|\*(?P<ref>[A-Za-z])(?P<alt>[A-Za-z])"    # substitution: reference base, read base
    r"|\+(?P<ins>[A-Za-z]+)"                    # insertion into the reference
    r"|-(?P<dele>[A-Za-z]+)"                    # deletion from the reference
    r"|~[A-Za-z]{2}(?P<intron>\d+)[A-Za-z]{2}"  # intron: splice signal, length, splice signal
)


def parse_cs_tag(cs_tag, start=1):
    """
    Parse a minimap2 'cs' tag into mutation positions and mutation types.

    Positions are counted along the reference, starting at ``start`` for the first aligned
    reference base. Both the short form (":10") and the long form ("=ACGT") of identical
    blocks are understood. Intron operations ("~gt100ag") advance the position by the intron
    length. Any text that is not a cs operation, such as a leading "cs:Z:" prefix, is ignored.

    Parameters:
    cs_tag (str): The cs string, with or without the "cs:Z:" prefix, e.g. ":10-ata:5*ag:3+t:7".
    start (int): Position assigned to the first aligned reference base. Defaults to 1.

    Returns:
    tuple: (positions, bases)
           - positions (list[int]): Reference positions of each recorded event.
           - bases (list[str]): For a substitution, the base in the read (as written in the tag);
             otherwise 'Deletion' (one entry per deleted base) or 'Insertion' (one entry per
             insertion event, at the position of the next reference base).

    Example:
    >>> parse_cs_tag("cs:Z::10-ata:5*ag:3+t:7")
    ([11, 12, 13, 19, 23], ['Deletion', 'Deletion', 'Deletion', 'g', 'Insertion'])
    """
    positions = []
    bases = []
    position = start

    for token in _CS_TOKEN_RE.finditer(cs_tag):
        if token.group('match') is not None:
            position += int(token.group('match'))
        elif token.group('match_seq') is not None:
            position += len(token.group('match_seq'))
        elif token.group('alt') is not None:
            bases.append(token.group('alt'))
            positions.append(position)
            position += 1
        elif token.group('dele') is not None:
            deletion_length = len(token.group('dele'))
            positions.extend(range(position, position + deletion_length))
            bases.extend(['Deletion'] * deletion_length)
            position += deletion_length
        elif token.group('ins') is not None:
            # Inserted bases consume no reference positions, so the counter does not advance
            bases.append('Insertion')
            positions.append(position)
        else:
            # Intron: skipped reference bases, no mutation recorded
            position += int(token.group('intron'))

    return positions, bases




def extract_positions_from_bam(input_bam_filename, output_filename, no_cs_tag_filename=None):
    """
    Write the parsed 'cs' tag details of every read in a BAM file to a text file.
    Reads without a 'cs' tag are optionally logged in a separate file.

    Parameters:
    input_bam_filename (str): Path to the input BAM file.
    output_filename (str): Path of the tab-separated output for reads that carry a 'cs' tag.
    no_cs_tag_filename (str, optional): If given, reads without a 'cs' tag are listed in this file.
                                        If None, such reads are skipped silently.

    Returns:
    None: Output is written to the given files.

    Output File Format (tab-separated):
    1. Query Name
    2. Reference Name
    3. Positions: Python list repr of the positions returned by parse_cs_tag. They are counted along the
       reference from 1 at the first aligned base, not offset by the start position.
    4. Bases: Python list repr of the substitution bases / 'Deletion' / 'Insertion' entries.
    5. Query Sequence: the read sequence, or '*' when the BAM record stores no sequence.
    6. Read Length: length of the stored sequence (0 when none is stored).
    7. Start Position: 0-based reference start of the alignment.

    The optional no-cs log has the columns: Query Name, Reference Name, the literal text "No cs tag".

    Example Usage:
    extract_positions_from_bam("sample.bam", "output_details.txt", "reads_without_cs_tag.log")
    """
    # ExitStack closes every handle opened here, including the optional log, on any exit path
    with contextlib.ExitStack() as stack:
        bamfile = stack.enter_context(pysam.AlignmentFile(input_bam_filename, "rb"))
        output_file = stack.enter_context(open(output_filename, "w"))
        no_cs_tag_file = (
            stack.enter_context(open(no_cs_tag_filename, "w")) if no_cs_tag_filename else None
        )

        for read in bamfile:
            if not read.has_tag("cs"):
                if no_cs_tag_file is not None:
                    no_cs_tag_file.write(f"{read.query_name}\t{read.reference_name}\tNo cs tag\n")
                continue

            positions, bases = parse_cs_tag(read.get_tag("cs"))

            sequence = read.query_sequence
            if sequence is None:
                # SEQ is '*' (e.g. minimap2 secondary alignments); write the SAM placeholder
                # instead of failing on len(None)
                sequence, read_len = "*", 0
            else:
                read_len = len(sequence)

            output_file.write(
                f"{read.query_name}\t{read.reference_name}\t{positions}\t{bases}\t{sequence}\t{read_len}\t{read.reference_start}\n"
            )


def analyze_fasta(fasta_filename, contig_ranges=None):
    """
    Find the positions of homopolymer runs (two or more identical consecutive bases) in each FASTA record.

    Sequences are upper-cased before comparison, so 'aA' counts as a run. Each record is keyed by
    the first whitespace-delimited word of its header line. This is the name SAM/BAM files use
    for reference sequences, and it is also the first column of the contig file.

    Parameters:
    fasta_filename (str): Path to the FASTA file. Blank lines are ignored.
    contig_ranges (dict, optional): contig name -> tuple whose first two items are the inclusive
                                    (start, end) range to keep. Extra items, such as the contig
                                    length, are ignored. Records without an entry are not filtered.
                                    If None or empty, no filtering is applied.

    Returns:
    dict: record name -> sorted list of 1-based positions that lie inside a homopolymer run.

    Raises:
    ValueError: on a duplicate record name, or on sequence data before the first header.
    """
    import itertools  # local import keeps this helper self-contained

    chunks_by_name = {}  # record name -> list of upper-cased sequence lines

    with open(fasta_filename, 'r') as file:
        current_sequence = None
        for line in file:
            line = line.strip()
            if not line:
                # Blank lines (including leading/trailing ones) carry no data
                continue
            if line.startswith('>'):
                header_fields = line[1:].split(maxsplit=1)
                current_sequence = header_fields[0] if header_fields else ''
                if current_sequence in chunks_by_name:
                    raise ValueError(f"Duplicate sequence identifier found: {current_sequence}")
                chunks_by_name[current_sequence] = []
            elif current_sequence is not None:
                chunks_by_name[current_sequence].append(line.upper())
            else:
                raise ValueError("File format error: Data encountered before sequence identifier.")

    consecutive_positions_dict = {}

    for sequence_name, chunks in chunks_by_name.items():
        sequence = ''.join(chunks)
        sequence_positions = []
        position = 1  # 1-based coordinate of the first base of the current run

        # groupby yields maximal runs of identical bases
        for _base, run in itertools.groupby(sequence):
            run_length = sum(1 for _ in run)
            if run_length > 1:
                sequence_positions.extend(range(position, position + run_length))
            position += run_length

        if contig_ranges and sequence_name in contig_ranges:
            range_start, range_end = tuple(contig_ranges[sequence_name])[:2]
            sequence_positions = [pos for pos in sequence_positions if range_start <= pos <= range_end]

        consecutive_positions_dict[sequence_name] = sequence_positions

    return consecutive_positions_dict


def analyze_mutations(positions, bases, valid_substitution_positions, valid_deletion_positions, homopolymer_positions, discard_deletions=False):
    """
    Classify the events reported for one read (as produced by parse_cs_tag).

    Parameters:
    positions (list): Position of each event.
    bases (list): Matching event type: 'Deletion', 'Insertion', or a substituted base.
    valid_substitution_positions (container): Positions at which substitutions are recorded.
    valid_deletion_positions (container): Positions at which deletions are recorded.
    homopolymer_positions (container): Positions whose deletions are counted as ambiguous.
    discard_deletions (bool): If truthy, the first deletion marks the read for discarding and stops
                              the scan; events processed before it remain in the result.

    Returns:
    dict: keys 'perfect_match' (bool), 'insertions' (list of positions), 'ambiguous_deletions'
          (list of positions), 'deletions' (list of positions), 'substitutions' (list of
          (position, base) tuples) and 'discard' (bool). 'perfect_match' is True when nothing was
          recorded, including when every event fell outside the valid positions.
    """
    result = {
        'perfect_match': False,
        'insertions': [],
        'ambiguous_deletions': [],
        'deletions': [],
        'substitutions': [],
        'discard': False,
    }

    # The cs tag reported no differences at all
    if not (positions or bases):
        result['perfect_match'] = True
        return result

    for pos, event in zip(positions, bases):
        if event == 'Insertion':
            result['insertions'].append(pos)
            continue

        if event != 'Deletion':
            # Substitution: only kept inside the tracked window
            if pos in valid_substitution_positions:
                result['substitutions'].append((pos, event))
            continue

        # Deletion
        if discard_deletions:
            result['discard'] = True
            return result
        if pos in homopolymer_positions:
            # Aligners cannot place a deletion unambiguously inside a homopolymer run
            result['ambiguous_deletions'].append(pos)
        elif pos in valid_deletion_positions:
            result['deletions'].append(pos)

    # Nothing recorded (every flag False, every list empty) -> treat as wild type
    result['perfect_match'] = not any(result.values())
    return result


def extract_and_analyze_reads(input_bam_filename, contig_data, consecutive_positions, output_text=False, text_filename=None, discard_deletions=False, discard_text_filename=None):
    """
    Open the BAM file (and optional report files) and hand the reads to process_bamfile.

    Parameters:
    input_bam_filename (str): Path to the input BAM file.
    contig_data (ContigData): Accumulator updated in place by process_bamfile.
    consecutive_positions (dict): Contig name -> homopolymer positions, forwarded to process_bamfile.
    output_text (bool): If True, per-read details are written to ``text_filename``. The header
                        row is written here.
    text_filename (str): Path of the per-read details file (used only when output_text is True).
    discard_deletions (bool): Forwarded to process_bamfile. When truthy and ``discard_text_filename``
                              is given, a file for reads rejected by process_bamfile is also opened.
    discard_text_filename (str): Path of the rejected-reads file.

    Returns:
    None. File-not-found, ValueError and any other exception are logged rather than raised.
    """
    print("Starting analysis...")

    logging.info("Initiating the BAM file processing.")
    logging.info(f"Starting analysis on {input_bam_filename}")

    try:
        # ExitStack closes (and flushes) every file opened below even if processing raises
        with contextlib.ExitStack() as stack:
            text_file = None
            if output_text:
                text_file = stack.enter_context(open(text_filename, 'w'))
                # One column per field process_bamfile writes, including the alignment end
                headers = ["Read_Name", "Contig", "Bases", "Positions", "ReadLength", "Sequence", "Tabulated", "Start_Pos", "End_Pos"]
                text_file.write('\t'.join(headers) + '\n')

            discard_file = None
            if discard_deletions and discard_text_filename:
                discard_file = stack.enter_context(open(discard_text_filename, 'w'))
                discard_headers = ["Read_Name", "Contig", "Sequence", "Start_Pos", "End_Pos"]
                discard_file.write('\t'.join(discard_headers) + '\n')

            bamfile = stack.enter_context(pysam.AlignmentFile(input_bam_filename, "rb"))
            process_bamfile(
                bamfile,
                contig_data,
                consecutive_positions,
                text_file=text_file,
                discard_file=discard_file,
                discard_deletions=discard_deletions,
            )

    except FileNotFoundError as e:
        # Report the file that is actually missing (it may be an output directory, not the BAM)
        logging.error(f"File not found: {e.filename or input_bam_filename}")
    except ValueError as e:
        logging.error(f"Data processing error in {input_bam_filename}: {e}")
    except Exception as e:
        logging.error(f"Unexpected error in {input_bam_filename}: {e}", exc_info=True)
    finally:
        logging.info(f"Completed analysis on {input_bam_filename}")

        
        
        
        
def process_bamfile(bamfile, contig_data, consecutive_positions, text_file=None, discard_file=None, discard_deletions=False):
    """
    Tabulate mutations for every read of an opened BAM file using the read's 'cs' tag.

    For each read, in order:
      1. If ``discard_deletions`` is truthy and the read's contig is known to ``contig_data``, a read
         whose alignment does not start at reference offset 0 and end at the contig length is
         counted as 'Incomplete_Alignment', optionally written to ``discard_file``, and skipped.
      2. A read without a 'cs' tag is counted as 'No_CS' and skipped.
      3. The 'cs' tag is parsed, and its positions are shifted to 1-based contig coordinates.
      4. ``analyze_mutations`` classifies the events. With ``discard_deletions`` truthy, a read
         containing a deletion is counted as 'discarded' and skipped.
      5. WT / insertion / ambiguous-deletion counts and per-position deletion/substitution counts
         are updated. If ``text_file`` is given, one tab-separated line is written for the read.

    Parameters:
    bamfile (AlignmentFile): An opened BAM file (any iterable of pysam AlignedSegment objects).
    contig_data (ContigData): Accumulator updated in place.
    consecutive_positions (dict): Contig name -> 1-based positions inside homopolymer runs.
    text_file (file, optional): Handle for per-read detail lines. If None, nothing is written.
    discard_file (file, optional): Handle for reads rejected as incomplete alignments.
    discard_deletions (bool): Enables the incomplete-alignment and deletion filters (steps 1 and 4).

    Returns:
    None: contig_data is updated in place.
    """
    # Homopolymer positions as sets, built once per contig instead of scanning a list per deletion
    homopolymer_sets = {}

    for read in bamfile:
        read_name = read.query_name
        tabulated = "No"  # becomes "Yes" once the read adds to the WT/deletion/substitution tallies
        contig = read.reference_name if read.reference_name else "No_Contig"
        # query_sequence is None when SEQ is '*' (e.g. minimap2 secondary alignments)
        sequence = read.query_sequence
        read_len = len(sequence) if sequence is not None else 0
        start_position = read.reference_start  # 0-based offset of the first aligned reference base
        end_position = read.reference_end  # 0-based exclusive end == 1-based last aligned base
        # NOTE(review): secondary/supplementary alignments are processed like primary ones and can
        # count a molecule more than once -- confirm whether they should be skipped.

        # Contigs missing from contig_data are treated as having no tracked positions
        contig_info = contig_data.get_data().get(contig) or {}

        # Incomplete-alignment filter
        if discard_deletions and contig_info:
            contig_length = contig_info.get('length')
            if start_position != 0 or end_position != contig_length:
                contig_data.update_incomplete_alignment(contig)
                if discard_file:
                    read_data = [
                        read_name,
                        contig,
                        sequence if sequence is not None else '',
                        str(start_position),
                        str(end_position),
                    ]
                    discard_file.write('\t'.join(read_data) + '\n')
                continue

        if not read.has_tag("cs"):
            contig_data.update_simple_count(contig, 'No_CS')
            continue

        cs_tag = read.get_tag("cs")

        valid_substitution_positions = contig_info.get('substitutions', {})
        valid_deletion_positions = contig_info.get('deletions', {})
        if contig not in homopolymer_sets:
            homopolymer_sets[contig] = set(consecutive_positions.get(contig, ()))
        homopolymer_positions = homopolymer_sets[contig]

        positions, bases = parse_cs_tag(cs_tag)
        # The cs positions start at 1 on the first aligned reference base. Shift them into 1-based
        # contig coordinates, the coordinates of the tracked ranges and the homopolymer positions.
        # Without the shift, reads that do not start at offset 0 are tallied at the wrong positions.
        if start_position > 0:
            positions = [position + start_position for position in positions]

        mutation_details = analyze_mutations(
            positions,
            bases,
            valid_substitution_positions,
            valid_deletion_positions,
            homopolymer_positions,
            discard_deletions=discard_deletions,
        )

        # Discard the read if the analysis is run with the no-deletion setting
        if mutation_details['discard']:
            contig_data.update_discarded(contig)
            continue

        if mutation_details['perfect_match']:
            contig_data.update_simple_count(contig, 'WT')
            tabulated = "Yes"

        # Insertions are counted once per read; they do not mark the read as tabulated
        if mutation_details['insertions']:
            contig_data.update_simple_count(contig, 'insertions')

        if mutation_details['ambiguous_deletions']:
            contig_data.update_simple_count(contig, 'ambiguous_deletions')

        for position in mutation_details['deletions']:
            contig_data.update_deletion(contig, position)
            tabulated = "Yes"

        for position, base in mutation_details['substitutions']:
            contig_data.update_substitution(contig, position, base)
            tabulated = "Yes"

        fields = [read_name, contig, bases, positions, read_len, sequence, tabulated, start_position, end_position]
        if any(value is None for value in fields):
            print(f"Data integrity check failed for read: {read_name}")
            continue

        if text_file is not None:
            read_data = [
                read_name,
                contig,
                ','.join(bases),
                ','.join(map(str, positions)),
                str(read_len),
                sequence,
                tabulated,
                str(start_position),
                str(end_position),  # end of the alignment on the reference
            ]
            text_file.write('\t'.join(read_data) + '\n')




def extract_contig_data(contig_data):
    """
    Flatten a ContigData object into a list of row dictionaries (e.g. for a DataFrame).

    Per-position tallies ('substitutions', 'deletions') yield one row per position with the keys
    Contig, Position, Mutation, Count. The per-contig counters 'insertions', 'ambiguous_deletions',
    'WT' and 'No_CS' yield one row each with the keys Contig, Mutation, Count. All other fields
    ('discarded', 'Incomplete_Alignment', 'length', ...) are left out.

    :param contig_data: ContigData object (anything with a get_data() method returning the counter dict).
    :return: List of dictionaries, in the order the contigs and their fields are stored.
    """
    position_keys = ('substitutions', 'deletions')
    contig_keys = ('insertions', 'ambiguous_deletions', 'WT', 'No_CS')

    rows = []
    for contig_name, counters in contig_data.get_data().items():
        for key, value in counters.items():
            if key in position_keys:
                rows.extend(
                    {'Contig': contig_name, 'Position': pos, 'Mutation': key, 'Count': tally}
                    for pos, tally in value.items()
                )
            elif key in contig_keys:
                # For these keys the value is already a single count
                rows.append({'Contig': contig_name, 'Mutation': key, 'Count': value})
    return rows


def summarize_mutation_data(contig_data, output_file_path):
    """
    Write a CSV summary with one line per contig.

    Columns: Contig, Total_Reads, Insertions, Deletions, Substitutions, Ambiguous_Deletions, WT, No_CS.
    Deletions and Substitutions are the sums of the per-position tallies. Total_Reads sums the
    contig's integer read counters (insertions, ambiguous deletions, WT, discarded, incomplete
    alignments, ...). It excludes 'No_CS', which is reported on its own, and 'length', which is the
    contig length and not a read count. Per-position tallies are not part of Total_Reads.

    Parameters:
    contig_data (ContigData): The data container with all the mutation information.
    output_file_path (str): The path to the file where the summary will be saved.

    Returns:
    None: The function writes the summary to a file.
    """
    # Integer fields that must not be summed into Total_Reads
    excluded_from_total = {'No_CS', 'length'}

    data = contig_data.get_data()

    with open(output_file_path, 'w') as outfile:
        headers = [
            'Contig', 'Total_Reads', 'Insertions', 'Deletions',
            'Substitutions', 'Ambiguous_Deletions', 'WT', 'No_CS'
        ]
        outfile.write(','.join(headers) + '\n')

        for contig, contig_info in data.items():
            total_reads = sum(
                count for key, count in contig_info.items()
                if isinstance(count, int) and key not in excluded_from_total
            )

            # Default to 0 / empty so partial entries such as 'No_Contig' are summarized too
            insertions = contig_info.get('insertions', 0)
            deletions = sum(contig_info.get('deletions', {}).values())
            substitutions = sum(contig_info.get('substitutions', {}).values())
            ambiguous_deletions = contig_info.get('ambiguous_deletions', 0)
            wt = contig_info.get('WT', 0)
            no_cs = contig_info.get('No_CS', 0)

            line_data = [
                contig,
                str(total_reads),
                str(insertions),
                str(deletions),
                str(substitutions),
                str(ambiguous_deletions),
                str(wt),
                str(no_cs)
            ]
            outfile.write(','.join(line_data) + '\n')

# Define a function to safely perform the division.
def safe_divide(row, wt_table=None):
    """
    Divide a row's 'Count' by the WT count of that row's contig, returning 0.0 instead of failing.

    Parameters:
    row (Mapping or pandas.Series): Must provide 'Count' and 'Contig'.
    wt_table (pandas.DataFrame, optional): Table indexed by contig name with a 'Count' column.
        If None, the module-level ``wt_counts`` (defined elsewhere in this script) is looked up at
        call time, as the function always did before.

    Returns:
    float: Count / WT count, or 0.0 when the WT count is zero or missing (NaN, None, pd.NA).

    Raises:
    KeyError: if the row's contig is not in the WT table.
    """
    table = wt_counts if wt_table is None else wt_table

    numerator = row['Count']
    denominator = table.loc[row['Contig']]['Count']

    # Check for missing values first: np.isnan(None) raises, and pd.NA == 0 has no truth value
    if pd.isna(denominator) or denominator == 0:
        return 0.0
    return np.divide(numerator, denominator)


############# Execute Code #############

#Process Contig_Ranges

# sys.argv values are always strings, so testing `discard_deletions` directly is
# true for any non-empty value, including "False" or "No". Parse it explicitly.
# NOTE(review): the accepted "true" spellings below are an assumption about what
# the calling shell script passes. extract_and_analyze_reads still receives the
# raw string; confirm how it interprets the flag.
discard_deletions_enabled = str(discard_deletions).strip().lower() in {'1', 'true', 't', 'yes', 'y'}

contig_ranges = {}

# Each non-blank line is whitespace-separated: name, start, end, length.
with open(contig_file, 'r') as file:
    for line_number, line in enumerate(file, start=1):
        line = line.strip()  # Remove leading/trailing whitespace and newline characters
        if not line:
            continue  # skip blank lines

        parts = line.split()
        if len(parts) < 4:
            raise ValueError(
                f"{contig_file}:{line_number}: expected 4 columns "
                f"(name start end length), got {len(parts)}"
            )

        contig_name = parts[0]
        contig_length = int(parts[3])

        if discard_deletions_enabled:
            # Use the full contig span, 1 through contig_length.
            contig_ranges[contig_name] = (1, contig_length, contig_length)
        else:
            start_position = int(parts[1])
            end_position = int(parts[2])
            contig_ranges[contig_name] = (start_position, end_position, contig_length)


#Setup Logging
# Configure the root logger before any pipeline work runs. logging.basicConfig()
# does nothing once the root logger has handlers. The module-level logging
# functions (logging.info, logging.error, ...) call basicConfig() automatically
# on first use. Configuring first therefore guarantees records go to log_file.
logging.basicConfig(
    filename=log_file,
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Handle to the root logger (the one configured above).
logger = logging.getLogger()

## Execute analyze_fasta function
consecutive_positions = analyze_fasta(fasta_filename)

#Initialize ContigData Storage
contigs = ContigData()
contigs.initialize_contigs(contig_ranges, consecutive_positions)


#########Run Main Functions
# Tabulate mutations from the aligned BAM into `contigs`. Any failure is written
# (with traceback) to the log file and NOT re-raised. The summary and plot steps
# below therefore still run on whatever `contigs` holds at that point.
try:
    extract_and_analyze_reads(
        input_bam_filename=bam_filename,
        consecutive_positions=consecutive_positions,
        contig_data=contigs,
        discard_deletions=discard_deletions,
        output_text=True,
        text_filename=text_filename,
        discard_text_filename=discard_text_filename,
    )
except Exception as err:
    # logging.exception == logging.error(..., exc_info=True)
    logging.exception(f"An error occurred: {err}")

# Flatten the tabulated counts (a list of dicts per the extractor) into a DataFrame.
data_list = extract_contig_data(contigs)
data_df = pd.DataFrame(data_list)

# With no records the frame has no columns, and every data_df['Mutation']
# lookup below would raise KeyError. Fall back to an empty frame with the
# columns this script uses.
if data_df.empty:
    data_df = pd.DataFrame({
        'Contig': pd.Series(dtype=object),
        'Mutation': pd.Series(dtype=object),
        'Position': pd.Series(dtype='int64'),
        'Count': pd.Series(dtype='int64'),
    })

#Make plots
# Directory where plots will be saved (no-op if it already exists).
os.makedirs(plot_path, exist_ok=True)

mutation_types = ['deletions', 'substitutions']

# WT read count per contig, the denominator of every rate below. Summing per
# contig guarantees a unique index, so each lookup yields a single scalar.
wt_counts = (
    data_df[data_df['Mutation'] == 'WT']
    .groupby('Contig')[['Count']]
    .sum()
)

# First, the per-mutation-type plots: one PNG per (contig, mutation type).
for mutation in mutation_types:
    # Rows for the current mutation type only.
    mutation_data = data_df[data_df['Mutation'] == mutation]

    # Total count per (contig, position); non-numeric columns are dropped.
    grouped_mutations = (
        mutation_data.groupby(['Contig', 'Position'])
        .sum(numeric_only=True)
        .reset_index()
    )

    # With no rows, DataFrame.apply(axis=1) returns a DataFrame rather than a
    # Series, and the 'Rate' assignment below would raise. Nothing to plot anyway.
    if grouped_mutations.empty:
        continue

    # Rate = positional count / WT read count of the contig (via safe_divide).
    grouped_mutations['Rate'] = grouped_mutations.apply(safe_divide, axis=1)

    label = f"{mutation.capitalize()} Rate"
    # grouped_mutations is already sorted by contig, so groupby() visits contigs
    # in the same order as unique() would, in a single pass.
    for contig, contig_data in grouped_mutations.groupby('Contig'):
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(contig_data['Position'], contig_data['Rate'],
                marker='o', linestyle='-', color='b', label=label)
        ax.set_title(f'{mutation.capitalize()} Mutation Rate in {contig}')
        ax.set_xlabel('Position')
        ax.set_ylabel('Rate')
        ax.grid(True)
        ax.legend()

        # Save the figure, then close it to free memory.
        filename = f"{contig}_{mutation}_rate.png"
        file_path = os.path.join(plot_path, filename)
        fig.savefig(file_path)
        plt.close(fig)

# Now, the combined plot: deletions + substitutions summed per (contig, position).
all_mutations = data_df[data_df['Mutation'].isin(mutation_types)]
grouped_all_mutations = (
    all_mutations.groupby(['Contig', 'Position'])
    .sum(numeric_only=True)
    .reset_index()
)

# Skip when there is nothing to plot. On an empty frame apply(axis=1) does not
# yield a Series, so assigning 'Rate' would raise.
if not grouped_all_mutations.empty:
    # Rate = combined positional count / WT read count of the contig.
    grouped_all_mutations['Rate'] = grouped_all_mutations.apply(safe_divide, axis=1)

    # Plotting for each contig for the combined data (single groupby pass).
    for contig, contig_data in grouped_all_mutations.groupby('Contig'):
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(contig_data['Position'], contig_data['Rate'],
                marker='o', linestyle='-', color='r', label="Combined Rate")
        ax.set_title(f'Combined Mutation Rate in {contig}')
        ax.set_xlabel('Position')
        ax.set_ylabel('Rate')
        ax.grid(True)
        ax.legend()

        # Save the figure, then close it to free memory.
        filename = f"{contig}_combined_mutation_rate.png"
        file_path = os.path.join(plot_path, filename)
        fig.savefig(file_path)
        plt.close(fig)


# Write the per-contig CSV summary of the tabulated counts to summary_path.
summarize_mutation_data(contigs, summary_path)




